#! /usr/bin/python3

# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
# SPDX-License-Identifier: GPL-2.0-only

"""
Test-build a kernel module against all kernel versions in a range of refs, or single git ref.

e.g. mass-build-test ttkmd linux v4.18 v6.3

Stops on first build failure.
"""

import argparse
from functools import lru_cache
import os
from pathlib import Path
import re
import subprocess
import sys
from typing import Optional, List, NamedTuple, Dict, Tuple

@lru_cache(maxsize=1)
def make_j() -> str:
    """Return the ``-jN`` argument for make, with N the number of CPUs this process may run on.

    The value is computed on the first call and served from the cache afterwards.
    """
    usable_cpus = os.sched_getaffinity(os.getpid())
    return "-j" + str(len(usable_cpus))

def subprocess_check_output_merged(args: List[str], working_dir: Path) -> bytes:
    """Run args in working_dir and return its stdout and stderr merged into one byte string.

    Raises subprocess.CalledProcessError if the command exits with a non-zero status;
    in that case the merged output is available as the exception's stdout attribute.
    """
    completed = subprocess.run(
        args,
        cwd=working_dir,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # fold stderr into the captured stdout stream
        check=True,
    )
    return completed.stdout

def prepare_kernel(kernel_dir: Path, ver: str) -> None:
    """Put the kernel tree at kernel_dir into a clean checkout of ver, configured for module builds.

    The steps run in order, each inside kernel_dir:
      1. ``git clean -dfx`` and ``git reset --hard`` discard leftovers from a previous build.
      2. ``git checkout <ver>`` switches the tree to the requested ref.
      3. ``make defconfig`` and ``make modules_prepare`` configure and prepare the tree.

    Nothing is returned. subprocess.CalledProcessError propagates from the first step
    that fails, carrying that command's merged stdout/stderr.
    """
    jobs = make_j()
    steps = [
        ["git", "clean", "-dfx"],
        ["git", "reset", "--hard"],
        ["git", "checkout", ver],
        ["make", jobs, "defconfig"],
        ["make", jobs, "modules_prepare"],
    ]

    for step in steps:
        subprocess_check_output_merged(step, kernel_dir)

def build_module(module_dir: Path, kernel_dir: Path) -> None:
    """Clean the module tree in module_dir, then build it against the kernel tree in kernel_dir.

    Both make invocations run inside module_dir. subprocess.CalledProcessError propagates
    if either of them fails.
    """
    jobs = make_j()
    subprocess_check_output_merged(["make", jobs, "clean"], module_dir)

    # The kernel tree we build against is incompletely built, so modpost can't find
    # symbol versions; that is why KBUILD_MODPOST_WARN=1 is passed.
    build_command = [
        "make",
        jobs,
        f"KDIR={kernel_dir.absolute()}",
        "KBUILD_MODPOST_WARN=1",
    ]
    subprocess_check_output_merged(build_command, module_dir)

class KernelVersion(NamedTuple):
    """A kernel release number; instances order like (major, minor, point) tuples."""

    major: int
    minor: int
    # Point release number; 0 stands for the base "vMAJOR.MINOR" release.
    point: int

    def major_minor(self) -> Tuple[int, int]:
        """Return the (major, minor) pair identifying the release series."""
        return self.major, self.minor

    def __str__(self) -> str:
        """Format as a tag name: "vMAJOR.MINOR" when point is 0, else "vMAJOR.MINOR.POINT"."""
        if self.point == 0:
            components = self.major_minor()
        else:
            components = self
        return "v" + ".".join(str(number) for number in components)

# Matches "vMAJOR.MINOR" with an optional ".POINT" suffix; the three numbers are captured as groups 1-3.
version_tag_re = re.compile(rb"v(\d+)\.(\d+)(?:\.(\d+))?")

def parse_version_tag(tag: bytes) -> Optional[KernelVersion]:
    """Parse a tag name such as b"v6.3" or b"v6.3.2" into a KernelVersion.

    Returns None when the whole tag does not match version_tag_re. A tag without
    a point component is given point 0.
    """
    match = version_tag_re.fullmatch(tag)
    if match is None:
        return None

    major_digits, minor_digits, point_digits = match.groups()
    point = int(point_digits) if point_digits else 0

    return KernelVersion(int(major_digits), int(minor_digits), point)

def insert_version_if_latest(test_versions: Dict[Tuple[int,int], KernelVersion], version: KernelVersion) -> None:
    """Store version in test_versions under its (major, minor) key if it beats the current entry.

    test_versions is modified in place. A series with no entry yet is treated as holding
    (0, 0, 0), so version is stored only when it compares greater than that floor.
    """
    series = (version[0], version[1])
    newest_so_far = test_versions.get(series, (0, 0, 0))

    if version > newest_so_far:
        test_versions[series] = version

def convert_ref_to_kernel_version(kernel_dir: Path, ref: str) -> Optional[KernelVersion]:
    """Return the kernel version that ref is tagged as, or None if it has no version tag.

    Runs ``git tag --points-at <ref> --sort=creatordate`` in kernel_dir and returns the
    first listed tag that parse_version_tag accepts.

    Raises RuntimeError, chained to the underlying CalledProcessError and including
    git's stderr, if the git command fails.
    """
    try:
        tag_output = subprocess.check_output(["git", "tag", "--points-at", ref, "--sort=creatordate"],
                                             cwd=kernel_dir, stderr=subprocess.PIPE)
    except subprocess.CalledProcessError as e:
        # "from e" records the git failure as the explicit cause of the RuntimeError.
        raise RuntimeError(f"Failed to list tags pointing to {ref} in {kernel_dir}.\n"
                           + e.stderr.decode(errors='replace')) from e

    for tag in tag_output.splitlines():
        version = parse_version_tag(tag)
        if version is not None:
            return version

    return None

def version_tags_between(kernel_dir: Path, start: str, end: str, use_latest_point: bool) -> List[str]:
    """Return the sorted kernel version tags to test from start through end, one per (major, minor).

    Candidate tags are those that contain start but do not contain end, as reported by
    ``git tag`` run in kernel_dir. Tags that parse_version_tag rejects are ignored.

    use_latest_point selects which tag represents each (major, minor) series:
      - True: the highest point release found for the series.
      - False: only the base release (point 0).

    If end itself is tagged as a kernel version, that version is included too; with
    use_latest_point it is replaced by the highest point release of its series among the
    tags that contain end.

    Raises RuntimeError, chained to the underlying CalledProcessError and including
    git's stderr, if a git command fails.
    """
    def tagged_versions(tag_filter: List[str], description: str) -> List[KernelVersion]:
        # Run "git tag <tag_filter> --sort=creatordate" and parse every version tag it prints.
        try:
            tag_output = subprocess.check_output(["git", "tag", *tag_filter, "--sort=creatordate"],
                                                 cwd=kernel_dir, stderr=subprocess.PIPE)
        except subprocess.CalledProcessError as e:
            # "from e" records the git failure as the explicit cause of the RuntimeError.
            raise RuntimeError(f"Failed to list tags {description} in {kernel_dir}.\n"
                               + e.stderr.decode(errors='replace')) from e

        parsed = map(parse_version_tag, tag_output.splitlines())
        return [v for v in parsed if v is not None]

    interval_versions = tagged_versions(["--contains", start, "--no-contains", end],
                                        f"between {start} and {end}")

    # (major,minor) => (major,minor,point). Holds the single version chosen for each series.
    test_versions: Dict[Tuple[int, int], KernelVersion] = {}

    if use_latest_point:
        for version in interval_versions:
            insert_version_if_latest(test_versions, version)
    else:
        for version in interval_versions:
            if version.point == 0:
                test_versions[version.major_minor()] = version

    # If end is a kernel version like "v6.5" or any ref that happens to be tagged as a kernel version,
    # then include that kernel version in the list of versions and resolve to latest point if requested.
    end_version = convert_ref_to_kernel_version(kernel_dir, end)
    if end_version is not None:
        if use_latest_point:
            end_series = end_version.major_minor()
            same_series = [v for v in tagged_versions(["--contains", end], f"starting from {end}")
                           if v.major_minor() == end_series]
            # The default keeps the tagged version itself if git reports no matching tag.
            end_version = max(same_series, default=end_version)

        test_versions[end_version.major_minor()] = end_version

    return [str(v) for v in sorted(test_versions.values())]

def test_one_build(module_dir: Path, kernel_dir: Path, ver: str) -> None:
    """Prepare the kernel tree at ver, then build the module against it.

    subprocess.CalledProcessError propagates from whichever step fails first.
    """
    prepare_kernel(kernel_dir=kernel_dir, ver=ver)
    build_module(module_dir=module_dir, kernel_dir=kernel_dir)

def main() -> int:
    """Parse the command line, test-build each selected version, and return the exit status.

    Returns 0 when every build passes (or builds are skipped), and 1 after reporting
    the first failed command together with its captured output.
    """
    parser = argparse.ArgumentParser(epilog=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument("module_dir", type=Path)
    parser.add_argument("kernel_dir", type=Path)
    parser.add_argument("start_ver", help="git ref of first version to build.")
    parser.add_argument("end_ver", nargs="?", help="Last version to build (inclusive). If not present, only builds start_ver.")
    parser.add_argument("--skip-builds", action='store_true', help=argparse.SUPPRESS)
    options = parser.parse_args()

    # With both ends given, expand the range into version tags; otherwise build the single ref as-is.
    if options.end_ver is not None:
        versions = version_tags_between(options.kernel_dir, options.start_ver, options.end_ver, True)
    else:
        versions = [options.start_ver]

    print("Testing versions:", ", ".join(versions))

    for ver in versions:
        print(ver)
        if options.skip_builds:
            continue

        try:
            test_one_build(options.module_dir, options.kernel_dir, ver)
        except subprocess.CalledProcessError as failure:
            print(f"Build failed for {ver}.")
            failed_command = " ".join(str(part) for part in failure.cmd)
            print(f"{failed_command} exited with code {failure.returncode}.")
            if failure.stdout:
                rule = "=" * 72
                print("Its output follows:")
                print(rule)
                print(failure.stdout.decode(errors='replace'))
                print(rule)
            # Stop at the first failure.
            return 1

    print("All builds passed.")

    return 0

# Run only when executed as a script, so that importing this file does not start a build run.
if __name__ == "__main__":
    sys.exit(main())
